#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations

import dataclasses
import logging
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import torch
from torch import Tensor
from torch.nn import Module
from torch.serialization import MAP_LOCATION

import lightly_train
from lightly_train._models.embedding_model import EmbeddingModel
from lightly_train._models.model_wrapper import ModelWrapper
from lightly_train._transforms.transform import NormalizeArgs

# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)

# Key under which the LightlyTrain metadata is stored inside a checkpoint dict
# (written by Checkpoint.to_dict, read by CheckpointLightlyTrain.from_checkpoint_dict).
CHECKPOINT_LIGHTLY_TRAIN_KEY: str = "lightly_train"


@dataclass(frozen=True)
class CheckpointLightlyTrain:
    """LightlyTrain metadata stored in a checkpoint dict under
    CHECKPOINT_LIGHTLY_TRAIN_KEY.

    Attributes:
        version: Version string. ``from_now`` fills in ``lightly_train.__version__``.
        date: Timestamp. ``from_now`` uses the current time as a timezone-aware
            datetime expressed in the local timezone.
        models: Model objects stored alongside the metadata.
        normalize_args: Normalization arguments stored alongside the models.
    """

    version: str
    date: datetime
    models: CheckpointLightlyTrainModels
    normalize_args: NormalizeArgs

    def to_dict(self) -> dict[str, Any]:
        """Convert the metadata to a dict that can be stored in a checkpoint.

        The date is stored as an ISO 8601 string, while ``models`` and
        ``normalize_args`` are converted with their own ``to_dict`` methods.
        This method does not copy the model objects.

        Returns:
            A dict with the keys "version", "date", "models" and
            "normalize_args", in that order.
        """
        # Build the dict explicitly instead of via dataclasses.asdict: asdict
        # deep-copies every non-dataclass value it reaches (including each model
        # inside ``models``), which would duplicate all model weights in memory
        # only for those copies to be replaced by the to_dict() results anyway.
        return {
            "version": self.version,
            "date": self.date.isoformat(),
            "models": self.models.to_dict(),
            "normalize_args": self.normalize_args.to_dict(),
        }

    @staticmethod
    def from_dict(checkpoint_info: dict[str, Any]) -> CheckpointLightlyTrain:
        """Inverse of ``to_dict``.

        Args:
            checkpoint_info: Dict as produced by ``to_dict``.

        Raises:
            KeyError: If one of the expected keys is missing.
        """
        return CheckpointLightlyTrain(
            version=checkpoint_info["version"],
            date=datetime.fromisoformat(checkpoint_info["date"]),
            models=CheckpointLightlyTrainModels.from_dict(checkpoint_info["models"]),
            normalize_args=NormalizeArgs.from_dict(checkpoint_info["normalize_args"]),
        )

    @staticmethod
    def from_now(
        models: CheckpointLightlyTrainModels, normalize_args: NormalizeArgs
    ) -> CheckpointLightlyTrain:
        """Create metadata stamped with the installed lightly_train version and
        the current time.
        """
        return CheckpointLightlyTrain(
            version=lightly_train.__version__,
            # Timezone-aware "now", converted to the local timezone.
            date=datetime.now(timezone.utc).astimezone(),
            models=models,
            normalize_args=normalize_args,
        )

    @staticmethod
    def from_checkpoint_dict(checkpoint: dict[str, Any]) -> CheckpointLightlyTrain:
        """Extract the metadata stored under CHECKPOINT_LIGHTLY_TRAIN_KEY from a
        full checkpoint dict.

        Raises:
            KeyError: If the checkpoint has no LightlyTrain entry or the entry is
                missing expected keys.
        """
        return CheckpointLightlyTrain.from_dict(
            checkpoint[CHECKPOINT_LIGHTLY_TRAIN_KEY]
        )


@dataclass(frozen=True)
class CheckpointLightlyTrainModels:
    """Bundle of the model objects saved with a LightlyTrain checkpoint.

    ``to_dict`` and ``from_dict`` only map the objects to and from dict entries;
    the objects themselves are passed through by reference.
    """

    model: Module
    wrapped_model: ModelWrapper
    embedding_model: EmbeddingModel

    def to_dict(self) -> dict[str, Any]:
        """Return the three models keyed by their field names (no copies)."""
        return dict(
            model=self.model,
            wrapped_model=self.wrapped_model,
            embedding_model=self.embedding_model,
        )

    @staticmethod
    def from_dict(models: dict[str, Any]) -> CheckpointLightlyTrainModels:
        """Inverse of ``to_dict``. Keys other than the three field names are ignored.

        Raises:
            KeyError: For the first of "model", "wrapped_model", "embedding_model"
                (checked in that order) that is missing from ``models``.
        """
        field_names = ("model", "wrapped_model", "embedding_model")
        return CheckpointLightlyTrainModels(
            **{name: models[name] for name in field_names}
        )


@dataclass(frozen=True)
class Checkpoint:
    """A training checkpoint: the model ``state_dict`` plus LightlyTrain metadata.

    This is the checkpoint layout produced by the PyTorch Lightning Trainer
    together with the ModelCheckpoint callback from
    lightly_train.callbacks.checkpoint. Serialized, it is a plain dict with a
    "state_dict" entry and an entry under CHECKPOINT_LIGHTLY_TRAIN_KEY.
    """

    state_dict: dict[str, Tensor]
    lightly_train: CheckpointLightlyTrain

    def to_dict(self) -> dict[str, Any]:
        """Return the checkpoint as the plain dict that ``save`` writes to disk."""
        serialized: dict[str, Any] = {"state_dict": self.state_dict}
        serialized[CHECKPOINT_LIGHTLY_TRAIN_KEY] = self.lightly_train.to_dict()
        return serialized

    @staticmethod
    def from_dict(checkpoint: dict[str, Any]) -> Checkpoint:
        """Inverse of ``to_dict``.

        Raises:
            KeyError: If the "state_dict" entry (checked first) or the
                LightlyTrain metadata entry is missing.
        """
        weights = checkpoint["state_dict"]
        metadata = CheckpointLightlyTrain.from_checkpoint_dict(checkpoint=checkpoint)
        return Checkpoint(state_dict=weights, lightly_train=metadata)

    @staticmethod
    def from_path(
        checkpoint: Path,
        map_location: MAP_LOCATION | None = "cpu",
        weights_only: bool = False,
    ) -> Checkpoint:
        """Load a checkpoint file written by ``save``.

        Args:
            checkpoint:
                Path to the checkpoint file.
            map_location:
                Forwarded to ``torch.load`` to remap where tensors are loaded,
                e.g. a device string such as 'cpu' or 'cuda:0' or a
                torch.device. Default: 'cpu'.
            weights_only:
                Forwarded to ``torch.load``. If False (default), the file is
                unpickled without restrictions. Only load checkpoints from
                trusted sources in this mode. If True, unpickling is restricted
                to tensors, primitive types, dicts and explicitly allow-listed
                globals. Because the metadata contains model objects, this
                requires registering them with:
                https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
                TODO(Philipp, 09/24): Expose weights_only argument to the user.

        Raises:
            KeyError: If the loaded dict lacks the expected entries.
        """
        logger.debug(
            f"Loading checkpoint from '{checkpoint}' with map_location '{map_location}' and weights_only {weights_only}"
        )
        raw = torch.load(
            checkpoint, map_location=map_location, weights_only=weights_only
        )
        return Checkpoint.from_dict(checkpoint=raw)

    def save(self, path: Path) -> None:
        """Write the checkpoint to ``path`` with ``torch.save``.

        Missing parent directories are created first.
        """
        logger.debug(f"Saving checkpoint to '{path}'")
        parent_dir = path.parent
        parent_dir.mkdir(parents=True, exist_ok=True)
        payload = self.to_dict()
        torch.save(payload, path)
